#!/usr/bin/env python3

# NOTE: the UserWarning filter that suppresses the pkg warning for pymilvus is installed after these imports
import argparse
import errno
import hashlib
import itertools
import logging
import os
import os.path
import sys
import threading
import time
import uuid
import warnings
from pathlib import Path

import docling
import qdrant_client
from docling.chunking import HybridChunker
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from fastembed import SparseTextEmbedding, TextEmbedding
from pymilvus import DataType, MilvusClient
from qdrant_client import models
from transformers import AutoTokenizer

# Ignore every UserWarning raised from this point on.
warnings.filterwarnings("ignore", category=UserWarning)

# Global Vars
# Embedding model names; each one can be overridden through the environment
# variable of the same name.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
SPARSE_MODEL = os.environ.get("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
# Collection name shared by the Qdrant and Milvus outputs.
COLLECTION_NAME = "rag"
# Exported through the process environment for the tokenizer library to read.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


class Converter:
    """Convert documents with Docling and write them out in the requested format.

    The constructor expands the given sources into a flat list of files.
    ``convert()`` runs Docling on each file and hands the results to the
    writer selected by ``format`` ("qdrant", "milvus", "markdown" or "json").
    """

    # Largest max_length Milvus accepts for a VARCHAR field.
    _MILVUS_TEXT_MAX_LENGTH = 65535
    # Number of inserted chunks between two Milvus flushes.
    _MILVUS_FLUSH_INTERVAL = 100

    def __init__(self, args):
        """Collect the source files and build the Docling converter.

        Args:
            args: Parsed command-line namespace; ``sources``, ``output``,
                ``format`` and ``ocr`` are read from it.
        """
        # Use the CPU when CUDA_VISIBLE_DEVICES is unset, empty, "none" or "-1";
        # any other value selects CUDA.
        if os.environ.get("CUDA_VISIBLE_DEVICES", "").lower() in ["", "none", "-1"]:
            dev = AcceleratorDevice.CPU
        else:
            dev = AcceleratorDevice.CUDA
        # OCR (image processing) is only enabled on request: leaving it off
        # drastically reduces RAM usage and speeds conversion up.
        pipeline_options = PdfPipelineOptions(
            accelerator_options=AcceleratorOptions(device=dev),
            do_ocr=args.ocr,
        )
        self.sources = []
        for source in args.sources:
            self.add(source)
        self.output = args.output
        self.format = args.format
        self.doc_converter = DocumentConverter(
            format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
        )

    def add(self, file_path):
        """Register ``file_path``: directories are walked, anything else is taken as one file."""
        if os.path.isdir(file_path):
            self.walk(file_path)
        else:
            self.sources.append(file_path)

    def chunk(self, docs):
        """Split converted documents into contextualized text chunks.

        Args:
            docs: Docling conversion results; ``.document`` of each is chunked.

        Returns:
            A ``(documents, ids)`` pair of equally long lists: the chunk texts
            and the integer id derived from each text by ``generate_hash``.
        """
        tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL, local_files_only=True)
        chunker = HybridChunker(tokenizer=tokenizer, max_tokens=8192, overlap=100, merge_peers=True)
        documents, ids = [], []

        for converted in docs:
            for piece in chunker.chunk(dl_doc=converted.document):
                # Extract the enriched text from the chunk
                doc_text = chunker.contextualize(chunk=piece)
                documents.append(doc_text)
                ids.append(self.generate_hash(doc_text))
        return documents, ids

    def convert_milvus(self, docs):
        """Embed all chunks of ``docs`` and store them in ``<output>/milvus.db``.

        Args:
            docs: Docling conversion results to chunk, embed and insert.
        """
        output_dir = Path(self.output)
        output_dir.mkdir(parents=True, exist_ok=True)
        milvus_client = MilvusClient(uri=os.path.join(self.output, "milvus.db"))
        dmodel = TextEmbedding(model_name=EMBED_MODEL, local_files_only=True)
        smodel = SparseTextEmbedding(model_name=SPARSE_MODEL, local_files_only=True)
        # Embed a throwaway sentence to learn the dense vector dimension.
        test_embedding = next(dmodel.embed("This is a test"))
        embedding_dim = len(test_embedding)
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        # Chunks may hold up to 8192 tokens, far more than 1000 characters, so
        # the text field is given the largest length Milvus allows.
        schema.add_field(
            field_name="text",
            datatype=DataType.VARCHAR,
            max_length=self._MILVUS_TEXT_MAX_LENGTH,
        )
        schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
        schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=embedding_dim)
        index_params = milvus_client.prepare_index_params()
        index_params.add_index(field_name="dense", index_name="dense_index", index_type="AUTOINDEX")
        index_params.add_index(field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX")
        milvus_client.create_collection(collection_name=COLLECTION_NAME, schema=schema, index_params=index_params)

        chunks, ids = self.chunk(docs)
        # Embed all chunks in one batch per model, then insert them one by one.
        dense_embeddings = list(dmodel.embed(chunks))
        sparse_embeddings = list(smodel.embed(chunks))

        total = len(chunks)
        rows = zip(chunks, ids, dense_embeddings, sparse_embeddings)
        for count, (text, chunk_id, dense, sparse) in enumerate(rows, start=1):
            milvus_client.insert(
                collection_name=COLLECTION_NAME,
                data=[{"id": chunk_id, "text": text, "sparse": sparse.as_dict(), "dense": dense}],
            )
            # Flush after every full batch of records to reduce RAM usage
            if count % self._MILVUS_FLUSH_INTERVAL == 0:
                milvus_client.flush(collection_name=COLLECTION_NAME)
            print(f"\rProcessed chunk {count}/{total}", end='', flush=True)
        # Flush whatever is left after the last full batch.
        milvus_client.flush(collection_name=COLLECTION_NAME)
        print("\n")

    def convert_qdrant(self, results):
        """Chunk ``results`` and add the chunks to a local Qdrant store at ``output``.

        Returns:
            Whatever ``QdrantClient.add`` returns for the inserted chunks.
        """
        qclient = qdrant_client.QdrantClient(path=self.output)
        qclient.set_model(EMBED_MODEL, local_files_only=True)
        qclient.set_sparse_model(SPARSE_MODEL, local_files_only=True)
        # optimizations to reduce ram
        qclient.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=qclient.get_fastembed_vector_params(on_disk=True),
            sparse_vectors_config=qclient.get_fastembed_sparse_vector_params(on_disk=True),
            quantization_config=models.ScalarQuantization(
                scalar=models.ScalarQuantizationConfig(
                    type=models.ScalarType.INT8,
                    always_ram=True,
                ),
            ),
        )
        chunks, ids = self.chunk(results)
        return qclient.add(COLLECTION_NAME, documents=chunks, ids=ids, batch_size=1)

    def show_progress(self, message, stop_event):
        """Animate ``message`` with trailing dots on stdout until ``stop_event`` is set.

        Intended to run in a helper thread while a conversion is in progress.
        """
        spinner = itertools.cycle([".", "..", "..."])
        while not stop_event.is_set():
            sys.stdout.write(f"\r{message} {next(spinner)}   ")
            sys.stdout.flush()
            time.sleep(0.5)
        # Blank the progress line; flush so the erase is not left in the buffer.
        sys.stdout.write("\r" + " " * 50 + "\r")
        sys.stdout.flush()

    def convert(self):
        """Convert every collected source and write the results in ``self.format``.

        Returns:
            The value returned by the Qdrant, markdown or JSON writer; ``None``
            for the Milvus format and for an unrecognized format.
        """
        results = []
        for source in self.sources:
            name = Path(str(source)).name
            stop_event = threading.Event()
            progress_thread = threading.Thread(target=self.show_progress, args=(f"Converting {name}", stop_event))
            progress_thread.start()
            try:
                results.append(self.doc_converter.convert(source))
            finally:
                # Always stop the spinner thread, even when the conversion raises.
                stop_event.set()
                progress_thread.join()
            print(f"Finished converting {name}")

        writers = {
            "qdrant": self.convert_qdrant,
            "markdown": self.convert_markdown,
            "json": self.convert_json,
            "milvus": self.convert_milvus,
        }
        writer = writers.get(self.format)
        if writer is None:
            return None
        return writer(results)

    def _output_dir_for(self, source):
        """Return the directory below ``self.output`` mirroring ``source``'s parent directory."""
        _drive, parent = os.path.splitdrive(os.path.dirname(source))
        # Strip leading separators: os.path.join discards everything before an
        # absolute component, and plain string concatenation glued relative
        # paths onto the output name ("out" + "docs" -> "outdocs").
        relative = parent.lstrip(os.sep + (os.altsep or ""))
        if not relative:
            return self.output
        return os.path.join(self.output, relative)

    def convert_markdown(self, results):
        """Save each converted document as ``<name>.md`` below the output directory."""
        for source, result in zip(self.sources, results):
            dirname = self._output_dir_for(source)
            os.makedirs(dirname, exist_ok=True)
            document = result.document
            document.save_as_markdown(os.path.join(dirname, f"{document.name}.md"))

    def convert_json(self, results):
        """Save each converted document as ``<name>.json`` below the output directory."""
        for source, result in zip(self.sources, results):
            dirname = self._output_dir_for(source)
            os.makedirs(dirname, exist_ok=True)
            document = result.document
            document.save_as_json(os.path.join(dirname, f"{document.name}.json"))

    def walk(self, path):
        """Append every regular file found below ``path`` to ``self.sources``."""
        for root, _dirs, files in os.walk(path, topdown=True):
            for entry in files:
                candidate = os.path.join(root, entry)
                # os.walk also lists entries that are not regular files
                # (e.g. dangling symlinks); skip those.
                if os.path.isfile(candidate):
                    self.sources.append(candidate)

    def generate_hash(self, document: str) -> int:
        """Derive a deterministic, non-negative 63-bit integer id from the document text.

        Equal texts always yield the same id.
        """
        sha256_hash = hashlib.sha256(document.encode('utf-8')).hexdigest()
        uuid_val = uuid.UUID(sha256_hash[:32])
        # Keep the low 63 bits so the value fits a signed int64 (Milvus requires signed 64-bit)
        return uuid_val.int & ((1 << 63) - 1)


def load():
    """Preload the embedding, sparse, PDF pipeline and tokenizer models.

    Nothing is converted or stored; the objects are created only for the
    side effect of loading the models. Presumably this populates the local
    model cache for later runs, which load with ``local_files_only=True``.
    """
    # Throwaway in-memory Qdrant client, used only to load both embedding models.
    warmup_client = qdrant_client.QdrantClient(":memory:")
    warmup_client.set_model(EMBED_MODEL)
    warmup_client.set_sparse_model(SPARSE_MODEL)
    # Initialise Docling's PDF pipeline.
    DocumentConverter().initialize_pipeline(InputFormat.PDF)
    # Load the tokenizer belonging to the dense embedding model.
    AutoTokenizer.from_pretrained(EMBED_MODEL)


# Command line: an optional OUTPUT positional followed by any number of SOURCES.
parser = argparse.ArgumentParser(description="process source files into RAG vector database")
parser.add_argument(
    "--debug",
    help="Output debug information",
    action="store_true",
)
parser.add_argument("output", help="Output directory", nargs="?")
parser.add_argument("sources", help="Source files", nargs="*")
parser.add_argument(
    "--format",
    choices=["qdrant", "json", "markdown", "milvus"],
    default="qdrant",
    help="Output format for RAG Data",
)
parser.add_argument(
    "--ocr",
    help="Enable embedded image text extraction from PDF (Increases RAM Usage significantly)",
    action="store_true",
)


def perror(*args, **kwargs):
    """Print to standard error; accepts the same arguments as ``print`` except ``file``."""
    print(*args, **kwargs, file=sys.stderr)


def eprint(e, exit_code):
    """Report ``e`` on standard error as ``Error: <message>`` and exit with ``exit_code``."""
    # Strip surrounding quotes from the message (some exceptions, e.g. KeyError, add them).
    message = str(e).strip("'\"")
    perror(f"Error: {message}")
    sys.exit(exit_code)


def setup_logging(args: argparse.Namespace):
    """Configure the root logger: DEBUG when ``args.debug`` is set, WARNING otherwise."""
    verbosity = logging.DEBUG if args.debug else logging.WARNING
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(
        force=True,
        format="{asctime} {name} {levelname}: {message}",
        style="{",
        level=verbosity,
    )


# Script entry point: parse the command line, then either preload the models
# ("load") or convert the given sources into the output directory.
try:
    args = parser.parse_args()
    setup_logging(args)

    if args.output is None:
        # "output" is declared with nargs="?", so argparse accepts a command
        # line without it; report the usage error here instead of failing
        # later on a None output path. parser.error() exits with status 2.
        parser.error("the following arguments are required: output")

    if args.output == "load":
        load()
    else:
        converter = Converter(args)
        converter.convert()
except docling.exceptions.ConversionError as e:
    eprint(e, 1)
except FileNotFoundError as e:
    eprint(e, errno.ENOENT)
except ValueError as e:
    eprint(e, errno.EINVAL)
except KeyboardInterrupt:
    # Ctrl-C: exit without printing a traceback.
    pass
